{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0c6cea67-3c4e-4798-ba5f-9d7191621dbb",
   "metadata": {},
   "source": [
    "# SciBERT for Single-Label Classification\n",
    "\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/center-for-threat-informed-defense/tram/blob/main/user_notebooks/fine_tune_single_label.ipynb)\n",
    "\n",
    "This notebook lets you continue fine-tuning the provided SciBERT single-label sequence classification model on your own data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b52f6f7-1202-4210-aa83-693194fba799",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p scibert_single_label_model\n",
    "!wget https://ctidtram.blob.core.windows.net/tram-models/single-label-202308303/config.json -O scibert_single_label_model/config.json\n",
    "!wget https://ctidtram.blob.core.windows.net/tram-models/single-label-202308303/pytorch_model.bin -O scibert_single_label_model/pytorch_model.bin\n",
    "!pip install torch transformers pandas scikit-learn tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f55728b-50a9-4073-9570-29a58b8f972d",
   "metadata": {},
   "source": [
    "This cell instantiates the label encoder. Do not modify it: the classes (i.e., ATT&CK techniques) and their order must match what the model expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "75c82006-f3b4-409a-9bde-035eb14070b5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array(['T1003.001', 'T1005', 'T1012', 'T1016', 'T1021.001', 'T1027',\n",
       "        'T1033', 'T1036.005', 'T1041', 'T1047', 'T1053.005', 'T1055',\n",
       "        'T1056.001', 'T1057', 'T1059.003', 'T1068', 'T1070.004',\n",
       "        'T1071.001', 'T1072', 'T1074.001', 'T1078', 'T1082', 'T1083',\n",
       "        'T1090', 'T1095', 'T1105', 'T1106', 'T1110', 'T1112', 'T1113',\n",
       "        'T1140', 'T1190', 'T1204.002', 'T1210', 'T1218.011', 'T1219',\n",
       "        'T1484.001', 'T1518.001', 'T1543.003', 'T1547.001', 'T1548.002',\n",
       "        'T1552.001', 'T1557.001', 'T1562.001', 'T1564.001', 'T1566.001',\n",
       "        'T1569.002', 'T1570', 'T1573.001', 'T1574.002'], dtype=object)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import OneHotEncoder as OHE\n",
    "\n",
    "CLASSES = [\n",
    "   'T1003.001', 'T1005', 'T1012', 'T1016', 'T1021.001', 'T1027',\n",
    "   'T1033', 'T1036.005', 'T1041', 'T1047', 'T1053.005', 'T1055',\n",
    "   'T1056.001', 'T1057', 'T1059.003', 'T1068', 'T1070.004',\n",
    "   'T1071.001', 'T1072', 'T1074.001', 'T1078', 'T1082', 'T1083',\n",
    "   'T1090', 'T1095', 'T1105', 'T1106', 'T1110', 'T1112', 'T1113',\n",
    "   'T1140', 'T1190', 'T1204.002', 'T1210', 'T1218.011', 'T1219',\n",
    "   'T1484.001', 'T1518.001', 'T1543.003', 'T1547.001', 'T1548.002',\n",
    "   'T1552.001', 'T1557.001', 'T1562.001', 'T1564.001', 'T1566.001',\n",
    "   'T1569.002', 'T1570', 'T1573.001', 'T1574.002'\n",
    "]\n",
    "\n",
    "encoder = OHE(sparse_output=False)\n",
    "encoder.fit([[c] for c in CLASSES])\n",
    "\n",
    "encoder.categories_"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "296b611b-67ae-4201-a559-2f5f56ab2723",
   "metadata": {},
   "source": [
    "This cell loads the training data, and you will need to modify it to load your own. By the end of the cell, the variable `data` must hold a DataFrame with a `text` column containing the segments and a `label` column containing one string per row, each of which is an ATT&CK ID that this model can classify. The DataFrame's index and any additional columns are ignored.\n",
    "\n",
    "For demonstration purposes, we will use the same single-label data that was produced during this TRAM effort, even though the model was trained on this data already. This cell is only present to show the expected format of the `data` DataFrame, and is not intended to be run as shown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2622429e-512d-443e-895b-eeec53afcecc",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>This file extracts credentials from LSASS simi...</td>\n",
       "      <td>T1003.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>It calls OpenProcess on lsass.exe with access ...</td>\n",
       "      <td>T1003.001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>It spreads to Microsoft Windows machines using...</td>\n",
       "      <td>T1210</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>SMB exploitation via EternalBlue</td>\n",
       "      <td>T1210</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>SMBv1 Exploitation via EternalBlue</td>\n",
       "      <td>T1210</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>The unpacked sample is approximately 540 KB</td>\n",
       "      <td>T1027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>decompress data blobs</td>\n",
       "      <td>T1140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>decompress them</td>\n",
       "      <td>T1140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>The decompression function</td>\n",
       "      <td>T1140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>Decrypted strings</td>\n",
       "      <td>T1140</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                  text      label\n",
       "0    This file extracts credentials from LSASS simi...  T1003.001\n",
       "1    It calls OpenProcess on lsass.exe with access ...  T1003.001\n",
       "2    It spreads to Microsoft Windows machines using...      T1210\n",
       "3                     SMB exploitation via EternalBlue      T1210\n",
       "4                   SMBv1 Exploitation via EternalBlue      T1210\n",
       "..                                                 ...        ...\n",
       "495        The unpacked sample is approximately 540 KB      T1027\n",
       "496                              decompress data blobs      T1140\n",
       "497                                    decompress them      T1140\n",
       "498                         The decompression function      T1140\n",
       "499                                  Decrypted strings      T1140\n",
       "\n",
       "[500 rows x 2 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_json('../single_label.json').drop(columns='doc_title').head(500)\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7136c953-fb1a-4e8a-887f-1ea4ec92add4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import transformers\n",
    "import torch\n",
    "\n",
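    "# Use the GPU when one is available; otherwise fall back to the CPU (much slower).\n",
    "cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",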
    "\n",
    "tokenizer = transformers.BertTokenizer.from_pretrained(\"allenai/scibert_scivocab_uncased\", max_length=512)\n",
    "bert = transformers.BertForSequenceClassification.from_pretrained('scibert_single_label_model').to(cuda).train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "e0e7574e-60c1-485e-96b1-a161babc69ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "train, test = train_test_split(data, test_size=0.2, shuffle=True)\n",
    "\n",
    "def _load_data(x, y, batch_size=10):\n",
    "    x_len, y_len = x.shape[0], y.shape[0]\n",
    "    assert x_len == y_len\n",
    "    for i in range(0, x_len, batch_size):\n",
    "        slc = slice(i, i + batch_size)\n",
    "        yield x[slc].to(cuda), y[slc].to(cuda)\n",
    "\n",
    "def _tokenize(instances: list[str]):\n",
    "    return tokenizer(instances, return_tensors='pt', padding='max_length', truncation=True, max_length=512).input_ids\n",
    "\n",
    "def _encode_labels(labels):\n",
    "    \"\"\":labels: a single-column DataFrame holding the `label` column, e.g. `data[['label']]`\"\"\"\n",
    "    return torch.Tensor(encoder.transform(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8ad37bdc-842f-46d7-b929-d0db6eae45e0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  102,  4546,   217,  ...,     0,     0,     0],\n",
       "        [  102,   106,  2289,  ...,     0,     0,     0],\n",
       "        [  102,   111,  1384,  ...,     0,     0,     0],\n",
       "        ...,\n",
       "        [  102,  9683,   972,  ...,     0,     0,     0],\n",
       "        [  102,   111, 24870,  ...,     0,     0,     0],\n",
       "        [  102, 14397,   111,  ...,     0,     0,     0]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train = _tokenize(train['text'].tolist())\n",
    "x_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "58cb43bb-ccf8-4ecc-849e-be8a52c26585",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/projects/TRAM2023/venv/lib/python3.9/site-packages/sklearn/base.py:432: UserWarning: X has feature names, but OneHotEncoder was fitted without feature names\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [1., 0., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train = _encode_labels(train[['label']])\n",
    "y_train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "560960f4-e966-439a-a9c6-e65c77b19bb7",
   "metadata": {},
   "source": [
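    "The printed tensor looks empty only because most entries are zero and the display is truncated. Taking the sum gives the number of training rows, which is consistent with exactly one `1` per row."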
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "bd10b10a-4dec-42d6-b3b9-1e1ef5019d3d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(400.)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3504844-b451-490e-8bad-c9d0f975999e",
   "metadata": {},
   "source": [
    "This cell contains the training loop. You may set `NUM_EPOCHS` to any positive integer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "cce77f9f-76d2-4e4c-90b4-77106f0154cb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "40it [00:23,  1.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1 loss: 0.009985628997674212\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "40it [00:20,  1.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2 loss: 0.005113609199179336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "40it [00:20,  1.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3 loss: 0.0038467945356387644\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "NUM_EPOCHS = 3\n",
    "\n",
    "from statistics import mean\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optim = AdamW(bert.parameters(), lr=2e-5, eps=1e-8)\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    epoch_losses = []\n",
    "    for x, y in tqdm(_load_data(x_train, y_train, batch_size=10)):\n",
    "        bert.zero_grad()\n",
    "        out = bert(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int), labels=y)\n",
    "        epoch_losses.append(out.loss.item())\n",
    "        out.loss.backward()\n",
    "        optim.step()\n",
    "    print(f\"epoch {epoch + 1} loss: {mean(epoch_losses)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af32bbd-2337-4d2d-bf71-0dacb40afb43",
   "metadata": {},
   "source": [
    "If the loss after the final epoch is still higher than you would like, do not re-run the previous cell, because doing so re-creates the optimizer and discards its state. Instead, uncomment the following cell and run it for as many additional epochs as you need."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5a01a0-eb3b-4564-bd4a-a43f819b9aeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NUM_EXTRA_EPOCHS = 1\n",
    "# for epoch in range(NUM_EXTRA_EPOCHS):\n",
    "#     epoch_losses = []\n",
    "#     for x, y in tqdm(_load_data(x_train, y_train, batch_size=10)):\n",
    "#         bert.zero_grad()\n",
    "#         out = bert(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int), labels=y)\n",
    "#         epoch_losses.append(out.loss.item())\n",
    "#         out.loss.backward()\n",
    "#         optim.step()\n",
    "#     print(f\"epoch {NUM_EPOCHS + epoch + 1} loss: {mean(epoch_losses)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d7311a4-6ac8-4397-8497-369865a7c924",
   "metadata": {},
   "source": [
    "The next cell evaluates performance after the additional fine-tuning. Scores on the example data will be high because the model has already been trained on most of these instances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "9eff1660-da8b-4d9d-bdcd-53d01d5a05a4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>P</th>\n",
       "      <th>R</th>\n",
       "      <th>F1</th>\n",
       "      <th>#</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>T1003.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1005</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1016</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1021.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1027</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1033</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1041</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1047</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1053.005</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1055</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1056.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1057</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1059.003</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1070.004</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1071.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1078</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1090</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1095</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1105</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.888889</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1106</th>\n",
       "      <td>0.900000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.947368</td>\n",
       "      <td>9.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1112</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1140</th>\n",
       "      <td>0.833333</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.909091</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1204.002</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1210</th>\n",
       "      <td>0.500000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1218.011</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1219</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1484.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1543.003</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1547.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1562.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.909091</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1566.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1569.002</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1570</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1574.002</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>(micro)</th>\n",
       "      <td>0.970000</td>\n",
       "      <td>0.970000</td>\n",
       "      <td>0.970000</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>(macro)</th>\n",
       "      <td>0.977451</td>\n",
       "      <td>0.974510</td>\n",
       "      <td>0.970229</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  P         R        F1     #\n",
       "T1003.001  1.000000  1.000000  1.000000   2.0\n",
       "T1005      1.000000  1.000000  1.000000   4.0\n",
       "T1016      1.000000  1.000000  1.000000   1.0\n",
       "T1021.001  1.000000  1.000000  1.000000   2.0\n",
       "T1027      1.000000  1.000000  1.000000  12.0\n",
       "T1033      1.000000  1.000000  1.000000   2.0\n",
       "T1041      1.000000  1.000000  1.000000   2.0\n",
       "T1047      1.000000  1.000000  1.000000   2.0\n",
       "T1053.005  1.000000  1.000000  1.000000   5.0\n",
       "T1055      1.000000  0.500000  0.666667   2.0\n",
       "T1056.001  1.000000  1.000000  1.000000   1.0\n",
       "T1057      1.000000  1.000000  1.000000   1.0\n",
       "T1059.003  1.000000  1.000000  1.000000   6.0\n",
       "T1070.004  1.000000  1.000000  1.000000   4.0\n",
       "T1071.001  1.000000  1.000000  1.000000   2.0\n",
       "T1078      1.000000  1.000000  1.000000   3.0\n",
       "T1090      1.000000  1.000000  1.000000   1.0\n",
       "T1095      1.000000  1.000000  1.000000   1.0\n",
       "T1105      1.000000  0.800000  0.888889   5.0\n",
       "T1106      0.900000  1.000000  0.947368   9.0\n",
       "T1112      1.000000  1.000000  1.000000   4.0\n",
       "T1140      0.833333  1.000000  0.909091   5.0\n",
       "T1204.002  1.000000  1.000000  1.000000   1.0\n",
       "T1210      0.500000  1.000000  0.666667   1.0\n",
       "T1218.011  1.000000  1.000000  1.000000   1.0\n",
       "T1219      1.000000  1.000000  1.000000   1.0\n",
       "T1484.001  1.000000  1.000000  1.000000   1.0\n",
       "T1543.003  1.000000  1.000000  1.000000   2.0\n",
       "T1547.001  1.000000  1.000000  1.000000   4.0\n",
       "T1562.001  1.000000  0.833333  0.909091   6.0\n",
       "T1566.001  1.000000  1.000000  1.000000   2.0\n",
       "T1569.002  1.000000  1.000000  1.000000   2.0\n",
       "T1570      1.000000  1.000000  1.000000   1.0\n",
       "T1574.002  1.000000  1.000000  1.000000   2.0\n",
       "(micro)    0.970000  0.970000  0.970000   NaN\n",
       "(macro)    0.977451  0.974510  0.970229   NaN"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bert.eval()\n",
    "\n",
    "x_test = _tokenize(test['text'].tolist())\n",
    "y_test = test['label']\n",
    "\n",
    "batch_size = 20\n",
    "preds = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in range(0, x_test.shape[0], batch_size):\n",
    "        x = x_test[i : i + batch_size].to(cuda)\n",
    "        out = bert(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int))\n",
    "        preds.extend(out.logits.to('cpu'))\n",
    "\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import precision_recall_fscore_support as calculate_score\n",
    "\n",
    "predicted_labels = (\n",
    "    encoder.inverse_transform(\n",
    "        F.one_hot(\n",
    "            torch.vstack(preds).softmax(-1).argmax(-1),\n",
    "            num_classes=len(encoder.categories_[0])\n",
    "        )\n",
    "        .numpy()\n",
    "    )\n",
    "    .reshape(-1)\n",
    ")\n",
    "\n",
    "predicted = list(predicted_labels)\n",
    "actual = y_test.tolist()\n",
    "\n",
    "labels = sorted(set(actual) | set(predicted))\n",
    "\n",
    "scores = calculate_score(actual, predicted, labels=labels)\n",
    "\n",
    "scores_df = pd.DataFrame(scores).T\n",
    "scores_df.columns = ['P', 'R', 'F1', '#']\n",
    "scores_df.index = labels\n",
    "scores_df.loc['(micro)'] = calculate_score(actual, predicted, average='micro', labels=labels)\n",
    "scores_df.loc['(macro)'] = calculate_score(actual, predicted, average='macro', labels=labels)\n",
    "\n",
    "scores_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
